package com.spiro.test.mr;

import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.net.URLClassLoader;

/**
 * A {@link URLClassLoader} that tries to load classes from its own URLs before asking its
 * parent, and that rewrites the bytecode of {@code org.apache.hadoop.mapreduce.Job} through
 * {@link JobAdapter} before defining it.
 *
 * <p>Lookup order implemented by {@link #loadClass(String)}:
 * <ol>
 *   <li>a class already loaded by this loader;</li>
 *   <li>for the Job class only, the transformed bytecode;</li>
 *   <li>for names accepted by {@code ClassPathUtils.isSystemClass}, the parent loader;</li>
 *   <li>this loader's own URLs;</li>
 *   <li>the parent loader, if there is one.</li>
 * </ol>
 *
 * @author Shaoping Huang
 * @since 1/4/2018
 */
public class MapReduceClassLoader extends URLClassLoader {

    private static final Logger LOG = LoggerFactory.getLogger(MapReduceClassLoader.class);

    /** Binary name of the class whose bytecode is transformed before it is defined. */
    private static final String JOB_CLASS_NAME = "org.apache.hadoop.mapreduce.Job";

    /** Creates a loader with no URLs; paths can be added later via {@link #addClassPath}. */
    public MapReduceClassLoader() {
        super(new URL[]{});
    }

    /**
     * Creates a loader over the given URLs with an explicit parent.
     *
     * @param urls   locations this loader searches for classes and resources
     * @param parent parent loader; may be {@code null}, in which case the loader built into
     *               the virtual machine is used for delegation
     */
    public MapReduceClassLoader(URL[] urls, ClassLoader parent) {
        super(urls, parent);
    }

    /**
     * Loads the class with the given binary name using the lookup order described on this
     * class.
     *
     * @param name binary name of the class
     * @return the loaded class, never {@code null}
     * @throws ClassNotFoundException if no step of the lookup could supply the class
     */
    @Override
    public synchronized Class<?> loadClass(String name) throws ClassNotFoundException {

        // Return the class directly if this loader has already loaded it.
        Class<?> c = this.findLoadedClass(name);
        if (c != null) {
            return c;
        }

        ClassNotFoundException ex = null;

        // The Job class is defined from transformed bytecode. If the transformation fails
        // we fall through to the regular lookup below.
        if (name.equals(JOB_CLASS_NAME)) {
            byte[] bytes = transformJobBytecode(name);
            if (bytes == null) {
                ex = new ClassNotFoundException("Transform job bytecode failed.");
            } else {
                c = defineClass(name, bytes, 0, bytes.length);
            }
        }

        if (c == null) {
            // System classes (as decided by ClassPathUtils.isSystemClass) are delegated to
            // the parent first so that they are not redefined by this loader.
            if (ClassPathUtils.isSystemClass(name)) {
                try {
                    c = loadFromParent(name);
                } catch (ClassNotFoundException e) {
                    ex = e;
                }
            }

            // Try this loader's own URLs.
            if (c == null) {
                try {
                    c = findClass(name);
                } catch (Exception e) {
                    // Keep the original message but preserve the cause for diagnosis.
                    ex = new ClassNotFoundException(e.getMessage(), e);
                }
            }

            // This loader could not find it; let the parent try.
            if (c == null && this.getParent() != null) {
                try {
                    c = this.getParent().loadClass(name);
                } catch (ClassNotFoundException e) {
                    ex = e;
                }
            }
        }

        if (c == null) {
            // Never throw a null reference: that would surface as a NullPointerException.
            throw ex != null ? ex : new ClassNotFoundException(name);
        }

        LOG.info("loaded {} from {}", c, c.getClassLoader());
        return c;
    }

    /**
     * Adds every URL derived from the given class path to this loader's search path.
     *
     * @param classPath class path string, converted to URLs by
     *                  {@code ClassPathUtils.getClassPathURLs}
     */
    public void addClassPath(String classPath) {
        URL[] cpUrls = ClassPathUtils.getClassPathURLs(classPath);
        for (URL cpUrl : cpUrls) {
            addURL(cpUrl);
        }
    }

    /**
     * Delegates loading to the parent loader, or to the default delegation of
     * {@link ClassLoader#loadClass(String, boolean)} when this loader has no parent.
     *
     * @param name binary name of the class
     * @return the loaded class
     * @throws ClassNotFoundException if the class cannot be found
     */
    private Class<?> loadFromParent(String name) throws ClassNotFoundException {
        ClassLoader parent = getParent();
        if (parent != null) {
            return parent.loadClass(name);
        }
        // A null parent stands for the loader built into the virtual machine; the inherited
        // two-argument loadClass knows how to consult it.
        return super.loadClass(name, false);
    }

    /**
     * Reads the class file of the Job class as a resource and runs it through
     * {@link JobAdapter}.
     *
     * @param jobClassName binary name of the Job class
     * @return the transformed bytecode, or {@code null} if the class file could not be found
     *         or read
     */
    private byte[] transformJobBytecode(String jobClassName) {
        String path = jobClassName.replace('.', '/').concat(".class");
        InputStream is = getResourceAsStream(path);
        if (is == null) {
            return null;
        }

        try {
            byte[] b = getBytes(is);

            ClassReader cr = new ClassReader(b);
            ClassWriter cw = new ClassWriter(cr, 0);
            cr.accept(new JobAdapter(cw), 0);
            return cw.toByteArray();
        } catch (IOException e) {
            LOG.warn("Failed to read bytecode of {}", jobClassName, e);
            return null;
        }
    }

    /**
     * Reads the stream to its end and closes it.
     *
     * @param is stream to drain; always closed before this method returns
     * @return all bytes read from the stream
     * @throws IOException if reading fails
     */
    private byte[] getBytes(InputStream is) throws IOException {
        try {
            // InputStream.available() is only an estimate, so the total size is not known
            // up front; accumulate until end of stream instead.
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            byte[] buf = new byte[4096];
            int len;
            while ((len = is.read(buf)) != -1) {
                out.write(buf, 0, len);
            }
            return out.toByteArray();
        } finally {
            try {
                is.close();
            } catch (IOException ignored) {
                // Best effort: the data has already been read (or a read error is propagating).
            }
        }
    }

}
